{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 安装类库\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "sys.path.append('/home/aistudio/external-libraries')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VUhrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDACWNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITjHhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# 读取数据\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train100(),\n",
    "    batch_size=8) # 数据集读取器\n",
    "data = next(reader()) # 读取数据\n",
    "index = 0 # 批次索引\n",
    "\n",
    "# 读取图像\n",
    "image = np.array([x[0] for x in data]).astype(np.float32) # 读取图像数据，数据类型为float32\n",
    "image = image * 255 # 从[0,1]转换到[0,255]\n",
    "image = image[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # 数据格式从CHW转换为HWC，数据类型转换为uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# 图像增强\n",
    "# sometimes = lambda aug: iaa.Sometimes(0.5, aug) # 随机进行图像增强\n",
    "# seq = iaa.Sequential([\n",
    "#     sometimes(iaa.CropAndPad(px=(-4, 4))),      # 随机裁剪填充像素\n",
    "#     iaa.Fliplr(0.5)])                           # 随机进行水平翻转\n",
    "# image = seq(image=image)\n",
    "\n",
    "# 读取标签\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # 读取标签数据，数据类型为int64\n",
    "vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "         'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "         'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "         'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "         'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "         'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "         'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "         'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "         'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "         'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "         'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "         'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "         'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "         'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "         'baby', 'boy', 'girl', 'man', 'woman',\n",
    "         'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "         'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "         'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "         'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "         'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # 标签名称列表\n",
    "vlist.sort() # 字母上升排序\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# 显示图像\n",
    "image = Image.fromarray(image)   # 转换图像格式\n",
    "image.save('./work/out/img.png') # 保存读取图像\n",
    "plt.figure(figsize=(3, 3))       # 设置显示大小\n",
    "plt.imshow(image)                # 设置显示图像\n",
    "plt.show()                       # 显示图像文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# 训练数据增强\n",
    "def train_augment(images):\n",
    "    # 转换格式\n",
    "    images = images * 255 # 从[0,1]转换到[0,255]\n",
    "    images = images.transpose((0, 2, 3, 1)).astype(np.uint8) # 数据格式从BCHW转换为BHWC，数据类型转换为uint8\n",
    "    \n",
    "    # 增强图像\n",
    "    sometimes = lambda aug: iaa.Sometimes(0.5, aug) # 随机进行图像增强\n",
    "    seq = iaa.Sequential([\n",
    "        sometimes(iaa.CropAndPad(px=(-4, 4))),      # 随机裁剪填充像素\n",
    "        iaa.Fliplr(0.5)])                           # 随机进行水平翻转\n",
    "    images = seq(images=images)\n",
    "    \n",
    "    # 减去均值\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # cifar数据集通道平均值\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # cifar数据集通道标准差\n",
    "    \n",
    "    images = (images/255.0 - mean) / stdv # 对图像进行归一化\n",
    "    images = images.transpose((0, 3, 1, 2)).astype(np.float32) # 数据格式从BHWC转换为BCHW，数据类型转换为float32\n",
    "    \n",
    "    return images\n",
    "\n",
    "# 验证数据增强\n",
    "def valid_augment(images):\n",
    "    # 转换格式\n",
    "    images = images * 255 # 从[0,1]转换到[0,255]\n",
    "    images = images.transpose((0, 2, 3, 1)).astype(np.uint8) # 数据格式从BCHW转换为BHWC，数据类型转换为uint8\n",
    "    \n",
    "    # 减去均值\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # cifar数据集通道平均值\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # cifar数据集通道标准差\n",
    "    \n",
    "    images = (images/255.0 - mean) / stdv # 对图像进行归一化\n",
    "    images = images.transpose((0, 3, 1, 2)).astype(np.float32) # 数据格式从BHWC转换为BCHW，数据类型转换为float32\n",
    "    \n",
    "    return images\n",
    "\n",
    "# 读取训练数据\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "    batch_size=128) # 构造数据读取器\n",
    "train_data = next(train_reader()) # 读取训练数据\n",
    "\n",
    "train_image = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取训练图像\n",
    "train_image = train_augment(train_image)                                                       # 训练图像增强\n",
    "train_label = np.array([x[1] for x in train_data]).reshape((-1, 1)).astype(np.int64)           # 读取训练标签\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# 读取验证数据\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test100(),\n",
    "    batch_size=128) # 构造数据读取器\n",
    "valid_data = next(valid_reader()) # 读取验证数据\n",
    "\n",
    "valid_image = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取验证图像\n",
    "valid_image = valid_augment(valid_image)                                                       # 验证图像增强\n",
    "valid_label = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)           # 读取验证标签\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# 模组结构：输入维度，输出维度，滑动步长，基础长度, 队列长度\n",
    "group_arch = [(3, 128, 1, 2, 3), (512, 256, 2, 2, 3), (1024, 512, 2, 2, 3)]\n",
    "group_dim  = 2048 # 模组输出维度\n",
    "class_dim  = 100  # 类别数量维度\n",
    "\n",
    "# 卷积单元\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化卷积单元，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim      - 输入维度\n",
    "            out_dim     - 输出维度\n",
    "            filter_size - 卷积大小\n",
    "            stride      - 滑动步长\n",
    "            act         - 激活函数\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "        \n",
    "        # 添加卷积\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=(filter_size-1)//2,                       # 输出特征图大小不变\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # 使用MARA 初始权重\n",
    "            bias_attr=False,                                  # 卷积输出没有偏置项\n",
    "            act=None)\n",
    "        \n",
    "        # 添加正则\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # 使用常量初始化权重\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # 使用常量初始化偏置\n",
    "            act=act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征进行卷积和正则\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "        \"\"\"\n",
    "        # 进行卷积\n",
    "        x = self.conv(x)\n",
    "        \n",
    "        # 进行正则\n",
    "        x = self.norm(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "# 投影单元\n",
    "class ProjUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=1, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化投影单元，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim      - 输入维度\n",
    "            out_dim     - 输出维度\n",
    "            filter_size - 卷积大小\n",
    "            stride      - 滑动步长\n",
    "            act         - 激活函数\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ProjUnit, self).__init__()\n",
    "        \n",
    "        # 添加池化\n",
    "        self.pool = Pool2D(\n",
    "            pool_size=filter_size,\n",
    "            pool_stride=stride,\n",
    "            pool_padding=0,\n",
    "            pool_type='avg')\n",
    "        \n",
    "        # 添加卷积\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=1,\n",
    "            stride=1,\n",
    "            padding=0,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # 使用MARA 初始权重\n",
    "            bias_attr=False,                                  # 卷积输出没有偏置项\n",
    "            act=None)\n",
    "        \n",
    "        # 添加正则\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # 使用常量初始化权重\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # 使用常量初始化偏置\n",
    "            act=act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征进行池化卷积和正则\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "        \"\"\"\n",
    "        # 进行池化\n",
    "        x = self.pool(x)\n",
    "        \n",
    "        # 进行卷积\n",
    "        x = self.conv(x)\n",
    "        \n",
    "        # 进行正则\n",
    "        x = self.norm(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "# 队列结构\n",
    "class SSRQueue(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=2, act=None):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化队列结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长，1保持不变，2下采样\n",
    "            queues  - 队列长度，分割尺度为2^(n-1)\n",
    "            act     - 激活函数\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(SSRQueue, self).__init__()\n",
    "        \n",
    "        # 添加队列变量\n",
    "        self.queues = queues # 队列长度\n",
    "        self.split_list = [] # 分割列表\n",
    "        \n",
    "        # 添加队列列表\n",
    "        self.queue_list = [] # 队列列表\n",
    "        for i in range(queues):\n",
    "            # 添加队列项目\n",
    "            queue_item = self.add_sublayer( # 构造队列项目\n",
    "                'queue_' + str(i),\n",
    "                ConvUnit(\n",
    "                    in_dim=(in_dim if i==0 else out_dim), # 每组队列项目除第一个外，in_dim=out_dim\n",
    "                    out_dim=out_dim,\n",
    "                    filter_size=3,\n",
    "                    stride=(stride if i==0 else 1), # 每组队列项目除第一块外，stride=1\n",
    "                    act=act))\n",
    "            self.queue_list.append(queue_item) # 添加队列项目\n",
    "            \n",
    "            # 计算输出维度\n",
    "            if i < (queues-1): # 如果不是最后一项\n",
    "                out_dim = out_dim//2 # 输出维度减半\n",
    "                self.split_list.append(out_dim) # 添加分割列表\n",
    "            \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "        \"\"\"\n",
    "        # 提取特征\n",
    "        x_list = [] # 队列输出列表\n",
    "        for i, queue_item in enumerate(self.queue_list):\n",
    "            if i < (self.queues-1): # 如果不是最后一项\n",
    "                x = queue_item(x) # 提取队列特征\n",
    "                x_item, x = fluid.layers.split(input=x, num_or_sections=[-1, self.split_list[i]], dim=1)\n",
    "                x_list.append(x_item) # 添加输出列表\n",
    "            else: # 否则不对特征分割\n",
    "                x = queue_item(x) # 提取队列特征\n",
    "                x_list.append(x) # 添加输出列表\n",
    "        \n",
    "        # 联结特征\n",
    "        x = fluid.layers.concat(input=x_list, axis=1) # 队列输出列表按通道维进行特征联结\n",
    "        \n",
    "        return x\n",
    "    \n",
    "# 基础结构\n",
    "class SSRBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化基础结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长\n",
    "            queues  - 队列长度\n",
    "            is_pass - 是否直连\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(SSRBasic, self).__init__()\n",
    "        \n",
    "        # 是否直连标识\n",
    "        self.is_pass = is_pass\n",
    "        \n",
    "        # 添加投影路径\n",
    "        self.proj = ProjUnit(in_dim=in_dim, out_dim=out_dim*4, filter_size=stride, stride=stride, act=None)\n",
    "        \n",
    "        # 添加卷积路径\n",
    "        self.con1 = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=1, stride=1, act='relu')\n",
    "        \n",
    "        if queues==1:\n",
    "            self.con2 = ConvUnit(in_dim=out_dim, out_dim=out_dim, filter_size=3, stride=stride, act='relu')\n",
    "        else:\n",
    "            self.con2 = SSRQueue(in_dim=out_dim, out_dim=out_dim, stride=stride, queues=queues, act='relu')\n",
    "        \n",
    "        self.con3 = ConvUnit(in_dim=out_dim, out_dim=out_dim*4, filter_size=1, stride=1, act=None)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "            y - 输出特征\n",
    "        \"\"\"\n",
    "        # 直连路径\n",
    "        if self.is_pass: # 是否直连\n",
    "            x_pass = x\n",
    "        else:            # 否则投影\n",
    "            x_pass = self.proj(x)\n",
    "        \n",
    "        # 卷积路径\n",
    "        x_con1 = self.con1(x)      # 特征降维\n",
    "        x_con2 = self.con2(x_con1) # 特征提取\n",
    "        x_con3 = self.con3(x_con2) # 特征升维\n",
    "        \n",
    "        # 输出特征\n",
    "        x = fluid.layers.elementwise_add(x=x_pass, y=x_con3, act='relu') # 直连路径与卷积路径进行特征相加\n",
    "        y = x\n",
    "        \n",
    "        return x, y\n",
    "    \n",
    "# 模块结构\n",
    "class SSRBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1, queues=1):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化模块结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长\n",
    "            basics  - 基础长度\n",
    "            queues  - 队列长度\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(SSRBlock, self).__init__()\n",
    "        \n",
    "        # 添加模块列表\n",
    "        self.block_list = [] # 模块列表\n",
    "        for i in range(basics):\n",
    "            block_item = self.add_sublayer( # 构造模块项目\n",
    "                'block_' + str(i),\n",
    "                SSRBasic(\n",
    "                    in_dim=(in_dim if i==0 else out_dim*4), # 每组模块项目除第一块外，输入维度=输出维度\n",
    "                    out_dim=out_dim,\n",
    "                    stride=(stride if i==0 else 1), # 每组模块项目除第一块外，stride=1\n",
    "                    queues=queues,\n",
    "                    is_pass=(False if i==0 else True))) # 每组模块项目除第一块外，is_pass=True\n",
    "            self.block_list.append(block_item) # 添加模块项目\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x      - 输入特征\n",
    "        输出:\n",
    "            x      - 输出特征\n",
    "            y_list - 输出特征列表\n",
    "        \"\"\"\n",
    "        y_list = [] # 模块输出列表\n",
    "        for block_item in self.block_list:\n",
    "            x, y_item = block_item(x) # 提取模块特征\n",
    "            y_list.append(y_item) # 添加输出列表\n",
    "            \n",
    "        return x, y_list\n",
    "\n",
    "# 模组结构\n",
    "class SSRGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化模组结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(SSRGroup, self).__init__()\n",
    "        \n",
    "        # 添加模组列表\n",
    "        self.group_list = [] # 模组列表\n",
    "        for i, block_arch in enumerate(group_arch):\n",
    "            group_item = self.add_sublayer( # 构造模组项目\n",
    "                'group_' + str(i),\n",
    "                SSRBlock(\n",
    "                    in_dim=block_arch[0],\n",
    "                    out_dim=block_arch[1],\n",
    "                    stride=block_arch[2],\n",
    "                    basics=block_arch[3],\n",
    "                    queues=block_arch[4]))\n",
    "            self.group_list.append(group_item) # 添加模组项目\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x      - 输入特征\n",
    "        输出:\n",
    "            x      - 输出特征\n",
    "            y_list - 输出特征列表\n",
    "        \"\"\"\n",
    "        y_list = [] # 模组输出列表\n",
    "        for group_item in self.group_list:\n",
    "            x, y_item = group_item(x) # 提取模组特征\n",
    "            y_list.append(y_item) # 添加输出列表\n",
    "            \n",
    "        return x, y_list\n",
    "        \n",
    "# 分割网络\n",
    "class SSRNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化分割网络，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(SSRNet, self).__init__()\n",
    "        \n",
    "        # 添加模组结构\n",
    "        self.backbone = SSRGroup() # 输出：N*C*H*W\n",
    "        \n",
    "        # 添加全连接层\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg') # 输出：N*C*1*1\n",
    "        \n",
    "        stdv = 1.0/(math.sqrt(group_dim)*1.0)                    # 设置均匀分布权重方差\n",
    "        self.fc = Linear(                                        # 输出：=N*10\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-stdv, stdv),   # 使用均匀分布初始权重\n",
    "            bias_attr=fluid.initializer.Constant(0.0),           # 使用常量数值初始偏置\n",
    "            act='softmax')\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入图像进行分类\n",
    "        输入:\n",
    "            x - 输入图像\n",
    "        输出:\n",
    "            x - 预测结果\n",
    "        \"\"\"\n",
    "        # 提取特征\n",
    "        x, y_list = self.backbone(x)\n",
    "        \n",
    "        # 进行预测\n",
    "        x = self.pool(x)\n",
    "        x = fluid.layers.reshape(x, [x.shape[0], -1])\n",
    "        x = self.fc(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tatol param: 21322980\n",
      "infer shape: [1, 100]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 输入数据\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "    \n",
    "    # 进行预测\n",
    "    backbone = SSRNet() # 设置网络\n",
    "    \n",
    "    infer = backbone(x) # 进行预测\n",
    "    \n",
    "    # 显示结果\n",
    "    parameters = 0\n",
    "    for p in backbone.parameters():\n",
    "        parameters += np.prod(p.shape) # 统计参数\n",
    "    \n",
    "    print('tatol param:', parameters)\n",
    "    print('infer shape:', infer.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd4VFX6wPHvOyWTRkgIoYYQmtJFQARxFWyL4uK64lpAxbKsvW1Rt6rrrq5uUX+6Kiru2kDsCHbFrmDoVQkBpCcESC9Tzu+PMwnpmYS0Sd7P88zDnTtn7n3v3PDec885914xxqCUUqrjcLR2AEoppVqWJn6llOpgNPErpVQHo4lfKaU6GE38SinVwWjiV0qpDkYTv1JKdTAhJ34RcYrIShFZVMNns0QkS0RWBV9XNW2YSimlmoqrAWVvAjYCcbV8/pIx5vojD0kppVRzCinxi0gyMBX4K3BrU6y4a9euJjU1tSkWpZRSHcby5cv3G2OSjmQZodb4HwR+C3Sqo8x5InIS8D1wizFmR9UCIjIbmA2QkpJCWlpaA8NVSqmOTUS2H+ky6m3jF5GzgUxjzPI6ir0FpBpjRgIfAP+rqZAxZo4xZqwxZmxS0hEdsJRSSjVSKJ27E4FpIrINmA+cIiLPVyxgjMk2xpQE3z4FjGnSKJVSSjWZehO/MeYOY0yyMSYVuBD42Bgzs2IZEelZ4e00bCewUkqpNqgho3oqEZG7gTRjzELgRhGZBviAA8CsxizT6/Wyc+dOiouLGxuWCoqMjCQ5ORm3293aoSil2hhprfvxjx071lTt3N26dSudOnUiMTEREWmVuNoDYwzZ2dnk5eXRr1+/1g5HKdWERGS5MWbskSyjTV25W1xcrEm/CYgIiYmJeuaklKpRm0r8gCb9JqK/o1KqNm0u8den2Otnb04xPn+gtUNRSqmwFHaJv8TrJzOvGG+g6fsmsrOzGTVqFKNGjaJHjx707t27/H1paWlIy7j88sv57rvvQl7nU089xc0339zYkJVSqsEaPaqntZQ1YTRHp3RiYiKrVq0C4M477yQ2NpZf//rXlcoYYzDG4HDUfMx85plnmjwupZRqSmFX4y9rum7JwUjp6ekMHTqUGTNmMGzYMPbs2cPs2bMZO3Ysw4YN4+677y4ve+KJJ7Jq1Sp8Ph/x8fHcfvvtHHPMMUyYMIHMzMw617N161YmT57MyJEjOf3009m5cycA8+fPZ/jw4RxzzDFMnjwZgLVr13LccccxatQoRo4cSUZGRvP9AEqpdqXN1vjvems9G3bnVpvvN4biUj+RbidOR8M6MIf2iuPPPxnWqHg2bdrEs88+y9ixdhTVfffdR5cuXfD5fEyePJnp06czdOjQSt/Jycnh5JNP5r777uPWW29l7ty53H777bWu49prr+Wqq65ixowZzJkzh5tvvplXXnmFu+66i08++YTu3btz6NAhAP7zn//w61//mgsuuICSkpJmOQNSSrVP4Vfjb6X1DhgwoDzpA8ybN4/Ro0czevRoNm7cyIYNG6p9JyoqijPPPBOAMWPGsG3btjrXsXTpUi688EIALr30Uj7//HMAJk6cyKWXXspTTz1FIGA7tU844QTuuece7r//fnbs2EFkZGRTbKZSqgNoszX+2mrmRaU+Nmfm0zcxhs5RLXdVakxMTPn05s2beeihh1i2bBnx8fHMnDmzxjHzERER5dNOpxOfz9eodT/55JMsXbqURYsWMXr0aFauXMkll1zChAkTWLx4MVOmTGHu3LmcdNJJjVq+UqpjCb8afzN27oYqNzeXTp06ERcXx549e3jvvfeaZLnjx49nwYIFADz//PPliTwjI4Px48fzl7/8hYSEBHbt2kVGRgYDBw7kpptu4uyzz2bNmjVNEoNSqv1rszX+2rRG525Vo0ePZujQoQwePJi+ffsyceLEJlnuo48+yhVXXMG9995L9+7dy0cI3XLLLWzduhVjDGeccQbDhw/nnnvuYd68ebjdbnr16sWdd97ZJDEopdq/NnWvno0bNzJkyJA6v1fqC7Bpby7JCVF0ifE0Z4hhL5TfUykVXtrdvXpCUTaQpxmu31JKqQ4hDBO/zfwBHb6olFKNEnaJXwQEIaBVfqWUapQwTPyCwwF+zftKKdUoISd+EXGKyEoRWVTDZx4ReUlE0kVkqYikNmWQVTlFa/xKKdVYDanx30Ttz9K9EjhojBkI/Bv4+5EGVheHQ7SNXymlGimkxC8iycBU4KlaipwD/C84/QpwqjTjk0CcIvibocY/efLkahdjPfjgg1xzzTV1fi82NhaA3bt3M3369BrLTJo0iarDV+uar5RSzSXUGv+DwG+B2p5+0hvYAWCM8QE5QOIRR1eL5qrxX3TRRcyfP7/SvPnz53PRRReF9P1evXrxyiuvNHlcSinVlOpN/CJyNpBpjFl+pCsTkdkikiYiaVlZWY1ejlOgOR7ANX36dBYvXlz+0JVt27axe/dufvSjH5Gfn8+pp57K6NGjGTFiBG+++Wa172/bto3hw4cDUFRUxIUXXsiQIUM499xzKSoqqnf98+bNY8SIEQwfPpzbbrsNAL/fz6xZsxg+fDgjRozg3//+NwAPP/wwQ4cOZeTIkeU3dlNKqVCEcsuGicA0ETkLiATiROR5Y8zMCmV2AX2AnSLiAjoD2VUXZIyZA8wBe+VunWt953bYu7bGj7r7/PgCBiIaeMeJHiPgzPtq/bhLly6MGzeOd955h3POOYf58+fz85//HBEhMjKS119/nbi4OPbv38/48eOZNm1arc+2feyxx4iOjmbjxo2sWbOG0aNH1xna7t27ue2221i+fDkJCQmcccYZvPHGG/Tp04ddu3axbt06gPLbMt93331s3boVj8dTPk8ppUJRb43fGHOHMSbZGJMKXAh8XCXpAywELgtOTw+WabbeVxHBmOa5iKtic0/FZh5jDL/73e8YOXIkp512Grt27WLfvn21Luezzz5j5kz7M40cOZKRI0fWud5vv/2WSZMmkZSUhMvlYsaMGXz22Wf079+fjIwMbrjhBt59913i4uLKlzljxgyef/55XK6wu+WSUqoVNTpjiMjdQJoxZiHwNPCciKQDB7AHiCNTR83cW+IjIyuf1MQY4pr41sznnHMOt9xyCytWrKCwsJAxY8YA8MILL5CVlcXy5ctxu92kpqbWeCvmppaQkMDq1at57733ePzxx1mwYAFz585l8eLFfPbZZ7z11lv89a9/Ze3atXoAUEqFpEEXcBljPjHGnB2c/lMw6WOMKTbGnG+MGWiMGWeMadbnAEa4bNilzdDQHxsby+TJk7niiisqderm5OTQrVs33G43S5YsYfv27XUu56STTuLFF18EYN26dfXeNnncuHF8+umn7N+/H7/fz7x58zj55JPZv38/gUCA8847j3vuuYcVK1YQCATYsWMHkydP5u9//zs5OTnk5+cf+cYrpTqEsKwiugr2MVAOkh/o3yzLv+iiizj33HMrjfCZMWMGP/nJTxgxYgRjx45l8ODBdS7jmmuu4fLLL2fIkCEMGTKk/MyhNj179uS+++5j8uTJGGOYOnUq55xzDqtXr+byyy8vf/LWvffei9/vZ+bMmeTk5GCM4cYbbyQ+Pv7IN1wp1SGE3W2ZAUzOTgL5+9nfaTDd4/SRg7XR2zIr1f50yNsyA4g4cGDQa3eVUqrhwjLxIw77JC69bYNSSjVYm0v8ITU9STBs0wxXcbUTrflMYqVU29amEn9kZCTZ2dn1J63gRVOiib9Gxhiys7OJjNT+D6VUdW1qVE9ycjI7d+6k3ts5lBZAYTa5bsOBmKiWCS7MREZGkpyc3NphKKXaoDaV+N1uN/369au/4LrX4L3L+b+jn+OGi6Y1f2BKKdWOtKmmnpC5bS1f/CWtHIhSSoWf8Ez8Ltt27fQXtnIgSikVfsIz8bujAXBojV8ppRosTBN/WY2/+W+SppRS7U14Jn6XbePXxK+UUg0XnonfXZb4talHKaUaKrwTf0ATv1JKNVR4Jv7gqB53QJt6lFKqoUJ52HqkiCwTkdUisl5E7qqhzCwRyRKRVcHXVc0TblCwxu/SNn6llGqwUK7cLQFOMcbki4gb+EJE3jHGfFOl3EvGmOubPsQaOJx4cetwTqWUaoR6E3/woellz/VzB1+tfuvHUvHg0qYepZRqsJDa+EXEKSKrgEzgA2PM0hqKnScia0TkFRHp06RR1sDr8OhwTqWUaoSQEr8xxm+MGQUkA+NEZHiVIm8BqcaYkcAHwP9qWo6IzBaRNBFJq/cOnPXwOTy4tKlHKaUarEGjeowxh4AlwJQq87ONMWVZ+CmgxieLG2PmGGPGGmPGJiUlNSbecl6HR0f1KKVUI4QyqidJROKD01HA6cCmKmV6Vng7DdjYlEHWxO+MxGW0xq+UUg0VyqiensD/RMSJPVAsMMYsEpG7gTRjzELgRhGZBviAA8Cs5gq4TMAZSYTRGr9SSjVUKKN61gDH1jD/TxWm7wDuaNrQ6uZ3RuIxOS25SqWUahfC88pdwO/04MFLINDqI0uVUiqshG3iDzgjiaIEb0AfuK6UUg0Rtonf74wiUkrx+bXGr5RSDRG2iT/giiQSTfxKKdVQYZv4jdNDJKW8vnJna4eilFJhJWwTv88ZhUd83P3WutYORSmlwkrYJn6vIwIAD6WtHIlSSoWX8E384gEgShO/Uko1SNgm/lKxT+GK1MSvlFINEraJv6ypJ1I08SulVEOEbeIvxdb4talHKaUaJnwTv3buKqVUo4Rt4ndE2Aeua1OPUko1TNgm/imj+gPQPVLv1aOUUg0RtonfExUDaI1fKaUaKmwTPy7buesO6FO4lFKqIcI38buj7T+a+JVSqkFCeeZupIgsE5HVIrJeRO6qoYxHRF4SkXQRWSoiqc0RbCXuYI1fn7urlFINEkqNvwQ4xRhzDDAKmCIi46uUuRI4aIwZCPwb+HvThlkDlx3VExHQ5+4qpVRD1Jv4jZUffOsOvqreBP8c4H/B6VeAU0VEmizKmjjdBHASQSnG6D35lVIqVCG18YuIU0RWAZnAB8aYpVWK9AZ2ABhjfEAOkFjDcmaLSJqIpGVlZR1Z5CL4gvfk9+rDWJRSKmQhJX5jjN8YMwpIBsaJyPDGrMwYM8cYM9YYMzYpKakxi6jE7/AQRSk+fe6uUkqFrEGjeowxh4AlwJQqH+0C+gCIiAvoDGQ3RYB18TkjiRSt8SulVEOEMqonSUTig9NRwOnApirFFgKXBaenAx+bFmh49zsj8VCKz681fqWUCpUrhDI9gf+JiBN7oFhgjFkkIncDacaYhcDTwHMikg4cAC5stogrCDjLmnq0xq+UUqGqN/EbY9YAx9Yw/08VpouB85s2tPr5nVFEUkqpT2v8SikVqvC9chcwLtvGrzV+pZQKXVgn/oAzyjb1aBu/UkqFLKwTv3HZzl0d1aOUUqEL+8Rvm3q0xq+UUqEK78TvjiKKEq3xK6VUA4R14scVSRSleLWNXymlQhbeiT8ihigpZemW/a0diVJKhY2wTvwuTywAC7/d3MqRKKVU+AjrxN+9axcAphzduZUjUUqp8BHWib/s8YsOX0ErB6KUUuEjvBN/RFniL2rlQJRSKnyEeeK3bfwOryZ+pZQKVXgn/mBTj8tf2MqBKKVU+AjvxB9s6nH6tcavlFKhCu/E744BwOXTGr9SSoUqvBN/sMbv1hq/UkqFLJRHL/YRkSUiskFE1ovITTWUmSQiOSKyKvj6U03LanLlbfzFLbI6pZRqD0J59KIP+JUxZoWIdAKWi8gHxpgNVcp9bow5u+lDrEOEbepxB7TGr5RSoaq3xm+M2WOMWRGczgM2Ar2bO7CQOCPw49CmHqWUaoAGtfGLSCr2+btLa/h4goisFpF3RGRYE8QWSkAUSxQRAW3qUUqpUIXS1AOAiMQCrwI3G2Nyq3y8AuhrjMkXkbOAN4BBNSxjNjAbICUlpdFBV1TqiCTCaI1fKaVCFVKNX0Tc2KT/gjHmtaqfG2NyjTH5wem3AbeIdK2h3BxjzFhjzNikpKQjDN0qkUgitKlHKaVCFsqoHgGeBjYaY/5VS5kewXKIyLjgcrObMtDaeJ1RuLWpRymlQhZKU89E4BJgrYisCs77HZACYIx5HJgOXCMiPqAIuNAY0yLPQ/Q5o4go1Rq/UkqFqt7Eb4z5ApB6yjwCPNJUQTWE3xlFhDnQGqtWSqmwFN5X7gIBdzSRgWJa6ARDKaXCXtgnfuOKJooSvkzPZnu2PpBFKaXqE/aJP+COIUpKmPn0Us586PPWDkcppdq8sE/8EhFNDHZUT2Gpv5WjUUqpti/sEz+RcURJKW58rR2JUkqFhfBP/FFdAOiMtu8rpVQowj7xByLjAegs+Uidg06VUkpBO0j8+Y5OAMSTj0Mzv1JK1SvsE//gfn0BSJB8HJr3lVKqXmGf+OMS7M3e4iUf0Rq/UkrVK+wTP9Flnbv5dd9XQimlFNAeEr8nDp9xEC8F2savlFIhCP/EL0KuxBBPPkVeP8VevYhLKaXqEv6JH+gUn0S85APw4wc/a+VolFKqbWsXid8dm8hRcfbK3e3Zha0cjVJKtW3tIvET1YWYQF5rR6GUUmGhnST+BKL9mviVUioUoTxzt4+ILBGRDSKyXkRuqqGMiMjDIpIuImtEZHTzhFuLqASifTktukqllApXodT4fcCvjDFDgfHAdSIytEqZM4FBwdds4LEmjbI+0YlEBgqIwAugT+NSSqk61Jv4jTF7jDErgtN5wEagd5Vi5wDPGusbIF5EejZ5tLWJ7QZAIrkA+AKa+JVSqjYNauMXkVTgWGBplY96AzsqvN9J9YMDIjJbRNJEJC0rK6thkdYltjsASXIIAJ9fE79SStUm5MQvIrHAq8DNxpjcxqzMGDPHGDPWGDM2KSmpMYuoWbDG3y2Y+L2BQNMtWyml2pmQEr+IuLFJ/wVjzGs1FNkF9KnwPjk4r2V06gFojV8ppUIRyqgeAZ4GNhpj/lVLsYXApcHRPeOBHGPMniaMs24x9uwhCTuyx+fXGr9SStXGFUKZicAlwFoRWRWc9zsgBcAY8zjwNnAWkA4UApc3fah1cLrJc3Qur/F7tXNXKaVqVW/iN8Z8AXXf8djY8ZPXNVVQjZHjTCBJtMavlFL1aR9X7gK5ri50k4MAeLWNXymlatVuEn+/1AH0cNjBRqU+rfErpVRtQmnjDwtRCT3xOA7hIEBOkbe1w1FKqTar3dT4ST4OR6CUSY5VHCwsbe1olFKqzWo/if/oM/HHdOMi58fsPlTEjgN6X36llKpJ+0n8TjeMvIBTHCt5YPFqfnT/Er1Zm1JK1aD9JH7A2XMkTjEki70PkF/H8yulVDXtKvET3xeAPsHE/0X6/taMRiml2qT2lfgTbOIvq/HPeubb1oxGKaXapPaV+GO64ZMI+khma0eilFJtVvtK/A4HeZE9y5t6lFJKVde+Ej9wyNOrvKlHKaVUde0u8ed4Ktf4f8jW8fxKKVVRu0v8+6IGkiD5XOZ8D4Dzn/iqlSNSSqm2pd0l/rTEn/CBfzR/dD1HLIXsz9fbNyilVEXtLvEndY7hdf+JuCRAsuwn2u1s7ZCUUqpNCeXRi3NFJFNE1tXy+SQRyRGRVcHXn5o+zNBdMbEfib0HANBbssgr8fHp99rZq5RSZUKp8f8XmFJPmc+NMaOCr7uPPKzGczkdDB0yHIBksVfuXjZ3WWuGpJRSbUq9id8Y8xlwoAViaTK5jniKjZveordsUEqpqpqqjX+CiKwWkXdEZFgTLbPRojwudpmu9Nbx/EopVU1TJP4VQF9jzDHA/wFv1FZQRGaLSJqIpGVlNV9SvmhcCkXRvbTGr5RSNTjixG+MyTXG5Aen3wbcItK1lrJzjDFjjTFjk5KSjnTVtXI7HQwfOpxBEWHVQqWUUi3iiBO/iPQQEQlOjwsuM/tIl3vE4lOI8R3iZMfq1o5EKaXalFCGc84DvgaOFpGdInKliFwtIlcHi0wH1onIauBh4ELTFh59NepiMqMGMMf9TzqT39rRKKVUm+Gqr4Ax5qJ6Pn8EeKTJImoqcb34rM/VTP/+N6TK3taORiml2ox2d+VuRYVRvQH7RC6vP8C4v37I22v3tHJUSinVutp14vfGpQA28ecUecnMK+FPb9Z4AbJSSnUY7TrxR8V2Jtt0oo9kUuoLAODTB7ArpTq4dp34B3WPZYdJIlmyymv6fr8mfqVUx9auE/9xqV3YQzdOcq5l2s5/EEe+1viVUh1eu078AI5Yey3ZNO+7zIn4Nybgb+WIlFKqdbX7xJ/uHgzA+46TGO/YyIDAttYNSCmlWlm94/jD3RfRp/D0voEkSi5neD5jkOxs7ZCUUqpVtfvEb8TBAeLIM9F4jZOjHDsxxhC8y4RSSnU47b6pp6wv14uLraYHIyQD/8oX4fv3IXMTPHIcpH/YukEqpVQLavc1fn+FUTzfm2TOdi6FhdfaGYPPhv3fw/PnwY0roUv/6gvI2wedurdQtEop1fzafY3/D1OHlE9nmXgAfMnj7YxNi8EZYafXvlr9y5kb4V+DYcPC5g5TKaVaTLtP/MemJJRPL/KPZ69JIPfHD7Il0BMwMGQa9BkPG96s/uXv3wMTgDUv1bxwvw/mXQyba2kq+uYx2Jl25BuxeyVsWXLky2lr1r0K279qufXtWw/PTIUnT4HSgpZbr1JtTLtP/BUtN0czvuRRpjy7i2UBO8yT3mNg6Dmwby28NBO+fdomh9evgdXzbZn0D+Hbp6A4t/IC96yG7xbDC+fB+jegoMJjCAqy4d3b4c3rIBCoPajSAnsAAfjhG/uqqCAbXjgfXrmi8nL2roNVL1Yu+9Uj8K9hUFoY+o/SWrzF8Ob18M5tLbfO5f+F7V/AruWwY2nLrVepNqZDJf4ymXklfB0Yat/0GQfHXAjDz4MflsLiW+HQD7B6HmRthF7Hgq8YFv8KlvwVDu2AuWfClo9h+5eHF/ryZfDgcEh7xr7f9rn9N2sTPH06rHuteiDGwH+nwn+Ot+t8bTa8eAEUVnhy2Lu3QUEWFB2ApY/ByuftAeC1X8Ab18Dz0+GFn8PGRfDlQ5C7E75+BLZ+dngZBzLggz/DmgXw2i/t2cNzP7NNWWX8XvvvvvXw8LGwP716vL4SGzPAvg3wypX24BeK16+Gta8cfr/tC/AWwt41cHBb9fJblkD6R6EtO1TpH0LfE0Echw+yT50GRYeadj1KtXHtvnMXYNYJqfz3q22V5i0KTCCrNJ5f+/vznwVbeOTiJ4mi2Cb0gadBxiew4DL48d9swlv5HKTNtQlp/3ew8EZIHABdBsBVH0J2OnxyLyy6GYzfJsaIWHtg2ZkGH98D/U6CT++HXqNg1MV2/u6VgMAzZ0HODhvc4l/Byb+1iW/tyzD2Crvu935nk1beHsjcYNe99TOIToSXZtjvRna2ByiAa76G7kPh/T/CpkWHN37tAtuENTcNrvnKnjl8/i+45DXbPHUgA5Y+bg82g6fC0VPtAXH1PIjtDinj7YEm4IXszXDirbZZbMBkGHE+VB0qm7vbfjd3N4yYbud9/67tX/GX2gPSyb89XD5zkz0AOlxw0yqI7VZ5ecZUX0d9srfY7Rp/LZTm2wPPp3+3n23/Cgaf1bDlNQVvETjc4Kzw33DLEnBH2d9YqWYi9T0sS0TmAmcDmcaY4TV8LsBDwFlAITDLGLOivhWPHTvWpKU1Qft3iK57cQWL11S/F3/fxGi2Zxfy+MwxTBneo/KHAT84nHb60A8wZxJExMComfDJ3+z8Y2fCOY/aab8X5s+ALR+BJw6Sx8KMl2HVPHjjavB0hpIcQKDrICjYbw8qk38H7//eLuO4X8C3Tx6OoddouOI9mHOyTfZlybLr0XD1Fzb5OiNsoj6wFY6aAsuesMms6yCI6gIZS+DYS6D7cLsd3zwKZ/0D3r0DOifDwa02AXXqATk7gSp/E73Hwq40GHO5PftI/9D2jSSPhXeCCdsVBb4iGPdLewaVNNjG9e2TkJ8JX/wL3DFw+w/2DOrhYyH5OCjOsc0v42bDWQ/Y90//GPL3QkkepJ4I/SfZEVdDptn1ZW2Cma9XTpgVrXgWvn4UrnjXNs9997Y9yK57xY7eWjrHnj2V+dGv4dQ/1v0H1NR8JfCf8ZA8Dn72BGz+APL22oM+Bi5+CQaccvgMS687UUEistwYM/ZIlhFKjf+/2CdsPVvL52cCg4Kv44HHgv+2KY5a/uP4gnfrjHDV8HlZ0geIT4HfZhx+32MEfP4PGHnh4XlON/xsDiy41Ca7sVfa+UPPse39ETEw81X7vdICcEXaRD1utk3cnXrA1H/AyJ/bM4juw6DbULvciTfZppn4PrZWePaD4IoAgqOSTrjhcBxHnQGLf22TbnyKfZ1+N0R3sYlk/NV2XuZGSHsahp1ra+ovzYSBp9oDxJcP2vkleTbRDzwNfvKgXX5ZjTsQsGcgXfrbs5n3/2gT6rInYMTPbVNOxTMNbwHcm2yHxxZk2m3qOdJ+b9kcexD6/n17FjHzVduM9OkD9uwLoOcxh5uWvnrY1t43vw+b34PT/wIluRDRyfbTZG2yTWeZmyDnB/udSb+zsQ4/zzbFjb3Ctvvv/Lbyfg8EIHO9/R0aknA3LYa+J0BUQvXPAn57JrZ7pV2/32vPQA5kwNFnwqtX2YO4Jw469bT9Hzcst/P9pTD1n/Y7iQNCj0epWtRb4wcQkVRgUS01/ieAT4wx84LvvwMmGWPqfNRVS9f4b5q/kjdX7a42P6mTh6y8Ep69YhwnHZXUfAEcyIDIeJt8a3Jwu23a6Ny7adbnLYI9a2xTU23Jq/AArHrBJsCIGPsdd5Ttx3j5Mvjp4xARbTtgT/0TJB1d9zoDAdu/sHMZbHzLzhv2M1j/GnQfYTvQwSa3wWfDucFad8Bvm3bSPwBx2oNnWZOQ32vjWvWCbe6K7W63Z+tnkDjQdhLnVrgNR9mZR7dhNnlHxsN5T9mDao8R1WNedKs9+HU9ysbV/2R77caq5+3BZOKNtlzRIXsQGz4d3JEVttkf7K8ognkXwPjrYMrfqq/nvd/b3yY+xZ51OVzQbYidLs6xB6yTf2sP9iLw3LmVD3RlBp6toV7hAAAXFUlEQVRmz6qOOqPufaHarZaq8denN7CjwvudwXlt6hmHzlqSX2GJHVHjcjbzqXRNF4dVlNC3adfnjoKUek68ortUPlNwR9l/4/vALz4+PP/CF0Jbp8NhE2Vxjq1pD50Gp/zRHswGnAIf3W2T2Vn/tGcK5d9zwsUL7AEipqtt2injdNvX+GvsC+yopXWvwsLr7ftJd0DAB3G9bbt9XhFcNA+i4sHpqZyoq+p7gk38To9dzxcP2j6aTr1svEUHbf9OQZYtX5IHfY63fSkLb7BnZvn7gODfz7pX4OgpdlCAp5Odt+xJm/TH/dL2GX35oD07Oe4qu7y0ubZJa9i5trwx0O9k2LHMHkg8sfZg3Lm3HWk2/yLbdOfpZM/ktBlINVBT1PgXAfcZY74Ivv8IuM0YU606LyKzgdkAKSkpY7Zv335EwTfErQtW8dqKXVwyvi/PfVN9vc9feTwnDuraYvG0ezV1wDZ1e/VHf7HDMi9daA86YA84e9fY5rJQ49yz2h6QROxZUOZGWxt//jzYvQJ6jrKdrUsft30mRQfsgcsdY5N8t6G26anvxMNNW8N+BqffBSX58MSPbE39whcrNx/WJRAATPXyRYfgiZPgUPBvePozMPxnISzPD4XZ1TvKa5K5yR6cTrvTnimpNqUpavwdpqlnaUY2F8z5hhd/cTwXP1l9DPczlx/H5KND+E+hOg5vkW23P/os2+T1+T/tWUB8ij0jmHSbPZMBm6j9JbbJyl8KP3wNiE3crki4YUXT3fojZ5c9A1l4Pexda5u8jp0Jx19j1/39e7bvovgQTPs/SEi1/U7bv7JnZEdNsf0Xnli7vAMZtm8l4LVJf8tHduRY/8kw4TrbX7X1M9uPkzjAdph7OtnrIRJS7bUwJbl22uU5HGcgYA/CCX1tZ7bfa8/oys4swS6r6KDtD/IW2unOfezZW+4u+wr47ZnVvvX2LK7HSHvgFbG/cVS8vRYma6PdZyV59gDpjoTEQXa0XFS8PWi7PMHh0YfsMlweu4zszRDTzcbnL7VxHNoB0Qk2RmeE3e/Gb9dRV96ssWIj9oytOAdSToBBpzV697eVxD8VuB47qud44GFjzLj6ltnSib9MYamPoX96r9r8OZeM4YxhWrtRdTj0Azx5Kvz0sbr/45YW2uG18Sm2Y/r4a2yHelPLz7LDfb9723ZWO1y2yQsgoZ9NtAVZNuHk7ILUiXa4MgACcb3sIIPiCtcxxHSz/Sj9TrKjv8q4Y2znfF0iOkFknE3y3kIbT0nu4X6XisuKSbQJtbgB11DEdLPl/aV1l3NH2w72okM2ZqfHHpQrcrgBc/j36tzHnhF5gxc/RsTawQZFB22znq/EHkAcbnvgqthUWUkt+TTgt7+FOG1fzqTbQ93qalqkjV9E5gGTgK4ishP4M+AGMMY8DryNTfrp2OGclx9JQM0tOqLmTd6wJ1cTv6pbfAr8ZnP95SKi4ZLXmz+e2CSYcK19bf3c1tQjYiBlgq1VFh2Al2fZoawXv2Svs8jdbTv996y2F855Yu0oouE/s8kupkJz55hZtvbsirDNWXvX2iHInlibJHuMtDXy3SsPnwH4Sm15d7StGfcYYdcV19s2MxXut1ejF+6360voa69DcUfbV2ScrdmX5NnEG9fb9r3EJEGXfrbpLG+PrXGbYHNYfqb9t9dou8yyZr/SQhtn5+TDtXhvkV1W2dlO2eCByDi7TG+RPQuoqUmuMdePVFR00B70XBGNX0YTCanG3xxaq8ZfJvX2xdXm/eSYXry1ejdb/nYWTod2mKl2IBCwNc2o+NaORDWRpqjxd8hbNgB8/KuTq817a7Ud7llQ6mvpcJRqHg6HJn1VTYdN/P2TYrn19KNq/KyoVB/IrpRqvzps4ge4bvLAGufnl2iNXynVfnXoxO90CM9feTwpXaIrzT9QUMpjn2zhtRX6YHalVPvTYTt3K0rPzOe0f31a42fb7pvawtEopVTttHO3iXhctf8M+3KLWzASpZRqfpr4qTvxH/+3j2itsyKllGoOmviBiDoSP8C+3JI6P1dKqXCiiR/wuOq+cdbGvbncuXA9BTraRynVDmjiB6IinHz8q5NZf9ePa/z8nkUb+O9X2/jf19vw+gOs2nFIm3+UUmFLE39Q/6RYYjw138dnS5a9OdWOA4X8dfFGfvrolzzxWUaNZZVSqq3TxF/Fi1cdz9/OreFJTcC8ZTvKH9p+3zub2HWoqMZySinVlmnir+KEgV05ukdsSGW3ZtVzm1qllGqDNPHXIDHGU38hYMfBQtIz85n9bBqZeXa8/+5DRcx6Zhmpty/mzVW7QlrOm6t2sfKHg42OVymlGkKv3K3F3pxiusd56HfH2w36XqzHVe1eP2P6JnD5xFROHdydqIjKI4iMMeXrmDk+hbU7c3jz+hOPLHilVLulV+42ox6dIxERFl4/kT9MHQLA6JR4ju/XhWsnDaj1ezXd4G359oNc/+JK/v3h97y49AcCAcO2/QWk3r6YcX/7qLzc89/8wOqdOfgDhrRtBwgEDMYYMoNXDz//zXZSb19MRlY+AN/vy2NvTjHb9hcw5i8fkJ6ZV2tc9R3gn/lyK79+eXWdZWpT7PWTevtinvt6W6O+r5RqWVrjD9Gmvbn0jo+iU6QbgMvmLuPT77MavbzRKfGs+KHux87NOiGVYb3i+M0ra3j1mhM477Gvyj+7+uQBPP7pFgD6d40hY38B549J5i8/Hc6/P/yeDbtzuWvaMD7elMmlE1I56g/vAHDG0O6cNaIn54zqhQSfJpSdX8KYez4E4IHpIxnfP5E1O3OYOrInALnFXk7756c8eMEoXE4HR/fohAj89uU1XHZCKskJUfzo/iX07BzJ13ecyvmPf8V5o5O5cFxKo36bghIfh4q89I6PYufBQnx+Q2rXmEYtS6n2piWfuTsFeAhwAk8ZY+6r8vks4AGgrFH7EWPMU3UtM9wSf1WlvgDFPnvffmPgmLveL//shAGJfLUlu0nWk5oYzbbswpDKDupmO6U3Z+aHVL5zlJs7pw3llpdqrulv+ssUIt1OPt60jyv+e3hfnTiwK+ce25tfBc8QhvWKY/3uXCLdDv527ghuXWDnl93gLiuvhM2ZeZwwoGv1lQDPfbOdCf0TGRiM/8yHPmfjnly23Te1/ElpTXGzPGMMC1fvJju/lCtO7HfEy1OqNbTUM3edwKPA6cBO4FsRWWiM2VCl6EvGmOuPJJhwEuFyVLrVQ1li2pKVT2piDE6HkFPkZf3uHEp8Ad5Zu4cFaTXf5vnGUwfx8Ec1P8s11KQPoSf8MjlF3lqTPsDgP75b4/wv0vfzRfr+8vfrd+cCUOwNlCd9sAfHCJeDi5/8hs2Z+aT/9UxcTkf5uh9dks4bK3eRmWdvibH13rOY/+0ONu6xy3t77Z5K6zXGcNdbGxjTN4GfHNOrQdvq9Qf4zcureWOVfcravtxiusRE8Pnm/Tx/1fEAXP/iCrLzS5k3ezxZeSXszy9hSM+4Bq1HqXBQb+IHxgHpxpgMABGZD5wDVE38ChiQdHgoaOcod3ktd/LR3bh/+jEcLChlzucZfLF5P+eM6sXjn2Zw7aQBnDakG798bjl7cmx7fiePi7wSH+ePSSa1awzpmfn8YeoQ4qMjmHDvR+XJssyffzKUu96qvEtOOiqJrfvz2XGgda43mP/tDwztGVd+QHrmy2389e2NgP1tcoq8lcpX7Ui/9oUV5dOPLknngfe+A+C/X23j2JR4khOiyc4v4cS/L+GZy48jt8jLx5syue+8keXfy8orIcLl4JfPpfFNxoHy+RUvwJv9bBrTRvVi0ZrDB5qfPfYlOw4UsfXes8qbxJRqL+pt6hGR6cAUY8xVwfeXAMdXrN0Hm3ruBbKA74FbjDE7aljWbGA2QEpKypjt27c30WZ0PPklPmI9Lr7bm0e/rjH4AgHOffQrth8o4Ph+iTxw/ki6dYokEDB8uHEfJb4AN8xbyevXnsC8ZT+wIG0nDoEot5OCUj8zx6fw/Dc/ALZv4a3Vu8kp8uIL2L+PIT3jymviZf7y0+FcdFwfHl2yhSKvv7zPoaVMGdaDUwZ347evrqk0f9LRSXyVnk2pP9Co5d582iAe/NCegb16zQm4HMJHmzKZcXwKcZHuaiOzQrVqxyGWbz/IlTU0MwUChm8yspkwIFEPNKpOLdLGH2LiTwTyjTElIvJL4AJjzCl1LTfc2/jDXU6Rl1iPC6dD2La/gL6J0ew8WERyQlSlxPPmql2s353LLacdhcflwOEQTrp/CTEeF+/c9KNKyzxUWMqsZ77lhAGJvLlqd61XNv9h6hBeW7GLiQMTufrkAfx54fpKtW2A4b3juPGUQbyyfCfvb9hXPj/C5aDU17iEDvB/Fx3LgrQdfL55f/2FazHvF+NZuHo3V56Yyusrd9E11kN0hJMLjkvBHxyJlVvsIz7KjcMhFHv9LNt6gEvnLgNgYLdY/jB1CGNTu3DXwvW8vLx6E+Cfzh7K+WOTeX/9Pvbnl/DUF1s5YUAiN546iG6dPOWDDBat2Y3H5eT0od0bvT0qvLRU4p8A3GmM+XHw/R0Axph7aynvBA4YYzrXtVxN/OGr2Gs7tSPd9dd89+QU0a1TJF+m7+fZr7fzxCVjcDpqr9EGAobvM/MY3ONw23p2fgklvgC94qPK5336fRaXBRNpmQn9E9mba6+/+G5vHgcLbVNSl5gIrp00gOljkomPjgBsE9B1L6zA43ZwXGoXotxOHv0knUOFlZufGiLC6ah0lpHUyUNesZdib+MPVHUZ1SeeVTvsyLBnZh1HQkwEPxwoJL/Yx7Kt2Yzum0B0hIu4SBffZBxgdN94BnXrRO+EKCKcDnyBANn5pRR5/RR7/fTrGoPL4SDS7WBbdiHREU4SYyJwOoQSXwCHCOmZ+RSU+hiTkoCjhv1ojKlUcSj2+nl5+U5+PjYZj8tJic9PRlZBed9JTpGXzlHuOrezLEdVXG4gYHA4hIMFpXjcDvwBg9vpoKxI2R13q8ZTVX6Jjyi3E4dUX36xz090RCit4bXbn19CfJS7/G++Kc7mWirxu7DNN6diR+18C1xsjFlfoUxPY8ye4PS5wG3GmPF1LVcTvzpSBwtKSYiJqPVzYwwlvkBIB6iKDhSU0iUmgm+3HcDlELZkFbA9u4AzhvbgoY82s+NAId/tq37NxMSBiWRkFZT309QkNTGafbklFAUPnr/58dE88N53JES7OffYZNKz8vns+ywigp3gnSJdZBeUNij+IyFiR6lV5HJIeZNfRXGRNikWef10jnKTX+Kj1Bcg1uOi1B/A7XSQV3z4upbEmIjybYlw2sER+SU+unXyEDAGYyg/cAo2SYpAYamfTh4XHpeDYl+AolI/Xn+AqAhnpeVHuh34/PYA0CUmgoAxZBeU0iU6Aq8/QLTHSYk3QMBAiddPsc+P12+3KyHaTZTbid8YXA4bV26xHVJcVV0psyx+YwxOh5Q/y8PlEKIjnHSN9eALGC4+PoWrT679eqC6tMioHmOMT0SuB97DDueca4xZLyJ3A2nGmIXAjSIyDfABB4BZRxKUUqGoK+mDTRwNTfpgzxAAjkvtAsCxKQnlnz11Wf3/3wIBgwGcDiG/xEeJ10+XmIhqtb2y2uh1kwc2KD5jDKX+ACt/OFSeMFftOERiTAT7ckvoHuchK6+EnvFRBAKGtO0H6BLjYVC3WNwuB7sPFeEPGPwBQ8AY9uWW4HE56BobQUmwGW3r/gJ6xUdRVOrH5RRcDsHjcvJNRjYHCks5aVBSeSy5xT6iIpyU+gJ0inRhzOEmuVJ/gE82ZTJpcDfAngEszTjAiQO74nHbA4PLYRO80+Go9DQ8Y+zvWOoLUOIL4HTYONxOB5l5xfSIi2RfbglOp5CRVUCk20FCMMkndfIgCB63A68vgNvloLDER36Jj7hIN52j3eWjzrLySvAHDDEeF06xB7mDhaVk5ZUwqHvN9+0Saq65i4BTBKdTKCzx8f6GfUzon0hclJsIp4NCrx+XQ0hOqH5AaUl6AZdSSoURvWWDUkqpBtPEr5RSHYwmfqWU6mA08SulVAejiV8ppToYTfxKKdXBaOJXSqkORhO/Ukp1MK12AZeIZAGNvT1nV6Dxd9lqe3R72jbdnrato21PX2NM0pGsoNUS/5EQkbQjvXKtLdHtadt0e9o23Z6G06YepZTqYDTxK6VUBxOuiX9OawfQxHR72jbdnrZNt6eBwrKNXymlVOOFa41fKaVUI4Vd4heRKSLynYiki8jtrR1PGRHpIyJLRGSDiKwXkZuC87uIyAcisjn4b0JwvojIw8HtWCMioyss67Jg+c0iclmF+WNEZG3wOw9LCzyVW0ScIrJSRBYF3/cTkaXBGF4SkYjgfE/wfXrw89QKy7gjOP87Eflxhfktui9FJF5EXhGRTSKyUUQmhPP+EZFbgn9r60RknohEhtP+EZG5IpIpIusqzGv2/VHbOpppex4I/r2tEZHXRSS+wmcN+t0bs29rZYwJmxf2CWBbgP5ABLAaGNracQVj6wmMDk53wj6ucihwP3B7cP7twN+D02cB72CfMjceWBqc3wXICP6bEJxOCH62LFhWgt89swW261bgRWBR8P0C4MLg9OPANcHpa4HHg9MXAi8Fp4cG95MH6Bfcf87W2JfA/4CrgtMRQHy47h+gN7AViKqwX2aF0/4BTgJGA+sqzGv2/VHbOpppe84AXMHpv1fYngb/7g3dt3XG2pz/0Zrhj30C8F6F93cAd7R2XLXE+iZwOvAd0DM4ryfwXXD6CeCiCuW/C35+EfBEhflPBOf1BDZVmF+pXDNtQzLwEXAKsCj4H2h/hT/k8v2BfTTnhOC0K1hOqu6jsnItvS+BzthEKVXmh+X+wSb+HdiE5wrunx+H2/4BUqmcKJt9f9S2jubYniqfnQu8UNPvWd/v3pj/e3XFGW5NPWV/7GV2Bue1KcFTrWOBpUB3E3wQPbAX6B6crm1b6pq/s4b5zelB4LdAIPg+EThkjCl7wnXFGMrjDn6eEyzf0O1sLv2ALOAZsU1XT4lIDGG6f4wxu4B/AD8Ae7C/93LCd/+UaYn9Uds6mtsV2DMPaPj2NOb/Xq3CLfG3eSISC7wK3GyMya34mbGH5LAYRiUiZwOZxpjlrR1LE3FhT8MfM8YcCxRgT/PLhdn+SQDOwR7QegExwJRWDaqJtcT+aKl9LiK/B3zAC829rlCEW+LfBfSp8D45OK9NEBE3Num/YIx5LTh7n4j0DH7eE8gMzq9tW+qan1zD/OYyEZgmItuA+djmnoeAeBFx1RBDedzBzzsD2TR8O5vLTmCnMWZp8P0r2ANBuO6f04CtxpgsY4wXeA27z8J1/5Rpif1R2zqahYjMAs4GZgQPNNDw7cmm4fu2ds3VBtkcL2ytLQNbyynr+BjW2nEFYxPgWeDBKvMfoHJH0v3B6alU7qxaFpzfBdsWnRB8bQW6BD+r2ll1Vgtt2yQOd+6+TOUOpmuD09dRuYNpQXB6GJU7sTKwHVgtvi+Bz4Gjg9N3BvdNWO4f4HhgPRAdXN//gBvCbf9QvY2/2fdHbetopu2ZAmwAkqqUa/Dv3tB9W2eczfkfrZn+4M/CjpjZAvy+teOpENeJ2FPGNcCq4OssbFvbR8Bm4MMKf5QCPBrcjrXA2ArLugJID74urzB/LLAu+J1HqKcDpwm3bRKHE3//4H+o9OAfoic4PzL4Pj34ef8K3/99MObvqDDSpaX3JTAKSAvuozeCiSJs9w9wF7ApuM7ngkkkbPYPMA/bP+HFnpFd2RL7o7Z1NNP2pGPb38tywuON/d0bs29re+mVu0op1cGEWxu/UkqpI6SJXymlOhhN/Eop1cFo4ldKqQ5GE79SSnUwmviVUqqD0cSvlFIdjCZ+pZTqYP4frGaBsdybf0wAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 25017s, best epoch: 201, best loss: 0.929860, best accuracy: 79.00%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "epoch_num = 300   # 训练周期，取值一般为[1,300]\n",
    "train_batch = 128 # 训练批次，取值一般为[1,256]\n",
    "valid_batch = 128 # 验证批次，取值一般为[1,256]\n",
    "displays = 100    # 显示迭代\n",
    "\n",
    "start_lr = 0.00001                         # 开始学习率，取值一般为[1e-8,5e-1]\n",
    "based_lr = 0.1                             # 基础学习率，取值一般为[1e-8,5e-1]\n",
    "epoch_iters = math.ceil(50000/train_batch) # 每轮迭代数\n",
    "warmup_iter = 10 * epoch_iters             # 预热迭代数，取值一般为[1,10]\n",
    "\n",
    "momentum = 0.9     # 优化器动量\n",
    "l2_decay = 0.00005 # 正则化系数，取值一般为[1e-5,5e-4]\n",
    "epsilon = 0.05     # 标签平滑率，取值一般为[1e-2,1e-1]\n",
    "\n",
    "checkpoint = False                   # 断点标识\n",
    "model_path = './work/out/ssrnet'     # 模型路径\n",
    "result_txt = './work/out/result.txt' # 结果文件\n",
    "class_num  = 100                     # 类别数量\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 准备数据\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "        batch_size=train_batch)\n",
    "    \n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test100(),\n",
    "        batch_size=valid_batch)\n",
    "    \n",
    "    # 声明模型\n",
    "    model = SSRNet()\n",
    "    \n",
    "    # 优化算法\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num) # 余弦衰减策略\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr) # 线性预热策略\n",
    "    \n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # 衰减学习策略\n",
    "        momentum=momentum,                                  # 优化动量系数\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # 正则衰减系数\n",
    "        parameter_list=model.parameters())\n",
    "    \n",
    "    # 加载断点\n",
    "    if checkpoint: # 是否加载断点文件\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # 加载断点参数\n",
    "        model.set_dict(model_dict)                                  # 设置权重参数\n",
    "        optimizer.set_dict(optimizer_dict)                          # 设置优化参数\n",
    "    else:          # 否则删除结果文件\n",
    "        if os.path.exists(result_txt): # 如果存在结果文件\n",
    "            os.remove(result_txt)      # 那么删除结果文件\n",
    "    \n",
    "    # 初始训练\n",
    "    avg_train_loss = 0 # 平均训练损失\n",
    "    avg_valid_loss = 0 # 平均验证损失\n",
    "    avg_valid_accu = 0 # 平均验证精度\n",
    "    \n",
    "    iterator = 1                                # 迭代次数\n",
    "    train_prompt = \"Train loss\"                 # 训练标签\n",
    "    valid_prompt = \"Valid loss\"                 # 验证标签\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # 训练图像\n",
    "    \n",
    "    best_epoch = 0           # 最好周期\n",
    "    best_accu = 0            # 最好精度\n",
    "    best_loss = 100.0        # 最好损失\n",
    "    train_time = time.time() # 训练时间\n",
    "    \n",
    "    # 开始训练\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # 训练模型\n",
    "        model.train() # 设置训练\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = train_augment(image_data)                                                        # 使用数据增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)                        # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                             # 转换数据类型\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon) # 使用标签平滑\n",
    "            label.stop_gradient = True                                                                # 停止梯度传播\n",
    "\n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            # 反向传播\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "            \n",
    "            # 显示结果\n",
    "            if iterator % displays == 0:\n",
    "                # 显示图像\n",
    "                avg_train_loss = train_loss.numpy()[0]                # 设置训练损失\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss) # 添加训练图像\n",
    "                ploter.plot()                                         # 显示训练图像\n",
    "                \n",
    "                # 打印结果\n",
    "                print(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\".format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "                \n",
    "                # 写入文件\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\\n\".format(\n",
    "                        iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "            \n",
    "            # 增加迭代\n",
    "            iterator += 1\n",
    "            \n",
    "        # 验证模型\n",
    "        valid_loss_list = [] # 验证损失列表\n",
    "        valid_accu_list = [] # 验证精度列表\n",
    "        \n",
    "        model.eval()   # 设置验证\n",
    "        for batch_id, valid_data in enumerate(valid_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = valid_augment(image_data)                                                        # 使用图像增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "            \n",
    "            label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64) # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                       # 转换数据类型\n",
    "            label.stop_gradient = True                                                          # 停止梯度传播\n",
    "            \n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算精度\n",
    "            valid_accu = fluid.layers.accuracy(infer,label)\n",
    "            \n",
    "            valid_accu_list.append(valid_accu.numpy())\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label)\n",
    "            valid_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            valid_loss_list.append(valid_loss.numpy())\n",
    "        \n",
    "        # 设置结果\n",
    "        avg_valid_accu = np.mean(valid_accu_list)             # 设置验证精度\n",
    "        \n",
    "        avg_valid_loss = np.mean(valid_loss_list)             # 设置验证损失\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss) # 添加训练图像\n",
    "        \n",
    "        # 保存模型\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)     # 保存权重参数\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path) # 保存优化参数\n",
    "        \n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best') # 保存权重\n",
    "            \n",
    "            best_epoch = epoch_id + 1                                    # 更新迭代\n",
    "            best_accu = avg_valid_accu                                   # 更新精度\n",
    "            best_loss = avg_valid_loss                                   # 更新损失\n",
    "    \n",
    "    # 显示结果\n",
    "    train_time = time.time() - train_time # 设置训练时间\n",
    "    print('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'.format(\n",
    "        train_time, best_epoch, best_loss, best_accu))\n",
    "    \n",
    "    # 写入文件\n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}\\n'.format(\n",
    "            train_time, best_epoch, best_loss, best_accu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.013995s, infer value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VUhrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDACWNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITjHhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_path = './work/out/img.png' # 图片路径\n",
    "model_path = './work/out/ssrnet-best' # 模型路径\n",
    "\n",
    "# 加载图像\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    功能:\n",
    "        读取图像并转换到输入格式\n",
    "    输入:\n",
    "        image_path - 输入图像路径\n",
    "    输出:\n",
    "        image - 输出图像\n",
    "    \"\"\"\n",
    "    # 读取图像\n",
    "    image = Image.open(image_path) # 打开图像文件\n",
    "    \n",
    "    # 转换格式\n",
    "    image = image.resize((32, 32), Image.ANTIALIAS) # 调整图像大小\n",
    "    image = np.array(image, dtype=np.float32) # 转换数据格式，数据类型转换为float32\n",
    "\n",
    "    # 减去均值\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1)) # cifar数据集通道平均值\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1)) # cifar数据集通道标准差\n",
    "    \n",
    "    image = (image/255.0 - mean) / stdv # 对图像进行归一化\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32) # 数据格式从HWC转换为CHW，数据类型转换为float32\n",
    "    \n",
    "    # 增加维度\n",
    "    image = np.expand_dims(image, axis=0) # 增加数据维度\n",
    "    \n",
    "    return image\n",
    "\n",
    "# 预测图像\n",
    "with fluid.dygraph.guard():\n",
    "    # 读取图像\n",
    "    image = load_image(image_path)\n",
    "    image = fluid.dygraph.to_variable(image)\n",
    "    \n",
    "    # 加载模型\n",
    "    model = SSRNet()                               # 加载模型\n",
    "    model_dict, _ = fluid.load_dygraph(model_path) # 加载权重\n",
    "    model.set_dict(model_dict)                     # 设置权重\n",
    "    model.eval()                                   # 设置验证\n",
    "    \n",
    "    # 前向传播\n",
    "    infer_time = time.time()              # 推断开始时间\n",
    "    infer = model(image)\n",
    "    infer_time = time.time() - infer_time # 推断结束时间\n",
    "    \n",
    "    # 显示结果\n",
    "    vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "             'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "             'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "             'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "             'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "             'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "             'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "             'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "             'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "             'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "             'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "             'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "             'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "             'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "             'baby', 'boy', 'girl', 'man', 'woman',\n",
    "             'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "             'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "             'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "             'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "             'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # 标签名称列表\n",
    "    vlist.sort() # 字母上升排序\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, vlist[np.argmax(infer.numpy())]) )\n",
    "    \n",
    "    image = Image.open(image_path) # 打开图像文件\n",
    "    plt.figure(figsize=(3, 3))     # 设置显示大小\n",
    "    plt.imshow(image)              # 设置显示图像\n",
    "    plt.show()                     # 显示图像文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.8.4 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
